import random
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid


"""
Utilities to help with visualizing images and other data
"""


def tensor_to_image(tensor):
  """
  Turn a (3, H, W) float image tensor in [0, 1] into a numpy image.

  Values are scaled to [0, 255] and shifted by 0.5 so that the truncating
  uint8 cast rounds to the nearest integer. They are then clamped to the
  valid byte range and rearranged into channels-last layout on the CPU.

  Inputs:
  - tensor: A torch tensor of shape (3, H, W) with elements in the range [0, 1]

  Returns:
  - ndarr: A uint8 numpy array of shape (H, W, 3)
  """
  # `mul` returns a new tensor, so the in-place ops below never modify the
  # caller's input.
  pixels = tensor.mul(255)
  pixels.add_(0.5)
  pixels.clamp_(0, 255)
  channels_last = pixels.permute(1, 2, 0)
  return channels_last.to('cpu', torch.uint8).numpy()


def visualize_dataset(X_data, y_data, samples_per_class, class_list):
  """
  Make a grid image showing `samples_per_class` random samples of each class.

  Row y of the grid holds samples whose label equals y, in the order given by
  `class_list`. As a side effect, each class name is drawn with `plt.text` on
  the current matplotlib axes, just left of its row. Display the returned
  image on those same axes (e.g. with `plt.imshow`) so labels and rows line up.

  Inputs:
  - X_data: Tensor of images with shape (N, 3, H, W)
  - y_data: Tensor of integer labels with shape (N,). Label i is matched
    against class_list[i].
  - samples_per_class: Number of samples to show per class. Samples are
    drawn uniformly at random with replacement, so a row may repeat an image.
  - class_list: List of class names
    e.g.) ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

  Returns:
  - A uint8 numpy array of shape (H_grid, W_grid, 3) containing the grid

  Raises:
  - ValueError: if a class in class_list has no samples in y_data
  """
  # Gap make_grid inserts between tiles. It is passed to make_grid explicitly
  # so the label positions computed below always match the actual layout.
  padding = 2
  img_height = X_data.shape[2]
  samples = []
  for y, cls in enumerate(class_list):
    idxs = (y_data == y).nonzero().view(-1)
    if idxs.numel() == 0:
      # Fail early with a clear message instead of randrange's "empty range".
      raise ValueError(
          f"No samples with label {y} ({cls!r}) found in y_data")
    # Row y starts at pixel padding + y * (H + padding). Put the label at the
    # row's vertical center. The true height H is used (not 2 * (H // 2)) so
    # labels do not drift when H is odd.
    tx = -4
    ty = (img_height + padding) * y + (img_height // 2 + padding)
    plt.text(tx, ty, cls, ha='right')
    for _ in range(samples_per_class):
      idx = idxs[random.randrange(idxs.shape[0])].item()
      samples.append(X_data[idx])

  img = make_grid(samples, nrow=samples_per_class, padding=padding)
  return tensor_to_image(img)
